from sklearn.svm import SVC
import read_mnist_data as data


# Train an RBF-kernel SVC (scikit-learn defaults) on the MNIST training set and
# report its accuracy on the test set.

# train_images is an ndarray of shape 60000 * 28 * 28. SVC expects a 2-D
# (n_samples, n_features) array, so flatten each image into one row. Using
# len(...) and -1 instead of hard-coding 60000 / 784 keeps this working for
# any number of images of any size.
train_images = data.load_train_images()
train_images = train_images.reshape(len(train_images), -1)
# train_labels is an ndarray of shape 60000
train_labels = data.load_train_labels()

# Flatten the test images the same way, once, so every later call can use
# them directly without re-reshaping with hard-coded sizes.
test_images = data.load_test_images()
test_images = test_images.reshape(len(test_images), -1)
test_labels = data.load_test_labels()

clf = SVC()
clf.fit(train_images, train_labels)

# Sanity check: the first 10 predictions next to their true labels.
print(clf.predict(test_images[:10]))
print(test_labels[:10])

# Evaluate the whole test set in one batched call. ClassifierMixin.score
# returns the mean accuracy, i.e. the fraction of test samples predicted
# correctly. This is the same quantity the per-sample predict loop computed,
# without making len(test_labels) separate predict calls.
acc = clf.score(test_images, test_labels)

print("acc = ", acc)

